Add the Segment Anything Model to KerasCV #1987
Conversation
Thanks for the PR!
This is super exciting 🎊
Just a few comments as I took a quick look
@keras.utils.register_keras_serializable(package="keras_cv")
class MLPBlock(keras.layers.Layer):
It seems a bit heavy to make this a class since we can just make this a pair of dense layers in a sequential wherever it's used.
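For illustration, the inlined alternative being suggested might look something like this (a sketch, not the PR's actual code; the dimensions are assumptions):

# Sketch of the suggested alternative: a pair of Dense layers in a
# Sequential, created inline wherever the block is needed.
import keras_core as keras

embedding_dim, mlp_dim = 256, 2048  # example sizes, assumed for illustration

mlp = keras.Sequential([
    keras.layers.Dense(mlp_dim, activation="gelu"),
    keras.layers.Dense(embedding_dim),
])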
It is actually used a few times in the mask decoder, so inlining the dense layers would duplicate a lot of code. Is there any side effect of having this? If not, I'd prefer keeping it.
If it's re-used in many places then I am alright with it -- it looked to me like it was only used once or twice but I probably missed some uses
@keras.utils.register_keras_serializable(package="keras_cv")
class SAMLayerNormalization(keras.layers.Layer):
Is there no way to parameterize keras.layers.LayerNormalization to achieve this?
There is keras.layers.LayerNormalization(epsilon=1e-6). I will push this in the next batch of commits.
We should probably double-check (don't take my word for it), but I'm not sure the numerics are the same. keras.layers.LayerNormalization is:
# Compute the batch normalization.
inv = 1 / ops.sqrt(variance + self.epsilon)
if scale is not None:
    scale = ops.cast(scale, inputs.dtype)
    inv = inv * scale
x = -mean * inv
if offset is not None:
    offset = ops.cast(offset, inputs.dtype)
    x = offset + x
outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
    x, inputs.dtype
)
outputs = ops.cast(outputs, input_dtype)
# If some components of the shape got lost due to adjustments, fix that.
outputs = ops.reshape(outputs, ops.shape(inputs))
For SAM, they call it LayerNorm2d() in the official implementation, but the official impl is taken directly from Detectron2, whose BatchNorm2D and LayerNorm are in turn taken from ConvNeXt: https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
Technically, the LayerNorm re-implementation in ConvNeXt, SAM, and Detectron2 shouldn't be the same LayerNorm as PyTorch's.
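For reference, the channels-first variant in the linked ConvNeXt code computes roughly the following (a paraphrase, not a verbatim copy):

# Paraphrase of ConvNeXt's channels-first LayerNorm: normalize over the
# channel axis only, then apply a learned per-channel scale and shift.
import torch

def layer_norm_2d(x, weight, bias, eps=1e-6):
    # x has shape (batch, channels, height, width)
    u = x.mean(1, keepdim=True)                # per-position channel mean
    s = (x - u).pow(2).mean(1, keepdim=True)   # per-position channel variance
    x = (x - u) / torch.sqrt(s + eps)
    return weight[:, None, None] * x + bias[:, None, None]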
More in these issues:
Thanks for linking these issues @DavidLandup0, I too was wondering why this was reimplemented. After some testing, I can confirm that keras.layers.LayerNormalization(epsilon=1e-6) is numerically equivalent to SAMLayerNormalization() for Segment Anything. Here's the code I used to test:
import os
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import torch
import keras_core as keras
from keras_cv.models.segmentation.segment_anything import sam_layers

sam_ln = sam_layers.SAMLayerNormalization()
ln = keras.layers.LayerNormalization(epsilon=1e-6)
sam_ln.build((1, 512, 512, 3))
ln.build((1, 512, 512, 3))
# Copy the weights so both layers start from identical parameters.
sam_ln.set_weights(ln.weights)

x_np = np.random.randint(0, 256, size=(1, 512, 512, 3), dtype=np.uint8)
x_np = x_np.astype(np.float32)
x = torch.tensor(x_np, requires_grad=True)
x_sam = torch.tensor(x_np, requires_grad=True)

# Forward pass through both layers, then backprop a ones tensor so we can
# compare gradients as well as outputs.
x_out_sam = sam_ln(x_sam)
x_out = ln(x)
x_out_sam.backward(torch.ones_like(x_out_sam))
x_out.backward(torch.ones_like(x_out))

# Outputs match.
np.testing.assert_allclose(
    x_out_sam.detach().numpy(),
    x_out.detach().numpy(),
    rtol=8e-5,
)
# Gradients w.r.t. the scale and shift parameters match.
np.testing.assert_allclose(
    ln.weights[0].value.grad.detach().numpy(),
    sam_ln.weights[2].value.grad.detach().numpy(),
    rtol=6e-7,
)
np.testing.assert_allclose(
    ln.weights[1].value.grad.detach().numpy(),
    sam_ln.weights[3].value.grad.detach().numpy(),
)
# Gradients w.r.t. the inputs match.
np.testing.assert_allclose(
    x_sam.grad.detach().numpy(),
    x.grad.detach().numpy(),
    atol=3e-7,
)
Awesome, thanks for checking! That simplifies things a lot :D
    image_pe,
    sparse_prompt_embeddings,
    dense_prompt_embeddings,
    multimask_output,
Should this have a default?
I think it should. I haven't set any sensible defaults yet. I will update all the layers with some default values that make sense.
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
I think we should probably have a top-level SegmentAnything model which takes an ImageEncoder as a backbone and subclasses Task. Then the high-level workflows can live on that model.
Then we can also include a preset which includes your ported weights!
Let's also include a reference to the paper and original implementation.
The image encoder for SAM is a near 1:1 from Detectron2's ViTDet.
IMO it makes sense to have ViTDet as a standalone class/network rather than a SAM-only encoder.
That way we can train it from scratch, use it as a standalone object detection model or a backbone for SAM, and reuse the same code across all of that.
I agree @DavidLandup0, I will move the layer to keras_cv/layers.
> I think we should probably have a top-level SegmentAnything model which takes an ImageEncoder as a backbone and subclasses Task. Then the high-level workflows can live on that model. Then we can also include a preset which includes your ported weights!
That's the plan. I will add a Task model in the keras_cv/models/segmentation/segment_anything/sam.py file. I am not yet sure how exactly the training step would be implemented with the Task API; we could just raise a NotImplementedError for now and write a train step as a follow-up.
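A rough skeleton of that plan might look like the following (names, import path, and constructor arguments are illustrative guesses based on the discussion, not the final KerasCV API):

# Illustrative skeleton only; the import path and signatures are assumptions.
from keras_cv.models.task import Task

class SegmentAnything(Task):
    def __init__(self, backbone, prompt_encoder, mask_decoder, **kwargs):
        super().__init__(**kwargs)
        self.backbone = backbone
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder

    def train_step(self, data):
        # Defer training support to a follow-up PR, as discussed.
        raise NotImplementedError(
            "Training is not yet supported for the SegmentAnything task."
        )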
Update: I have moved the image encoder to a standalone backbone. Let me know if that looks good to you. Thanks for the suggestion @DavidLandup0!
(I will add a Task model next.)
Great progress!
keras_cv/layers/detectron2_layers.py (Outdated)
from keras_cv.models.segmentation.segment_anything.sam_layers import MLPBlock


def get_rel_pos(query_size, key_size, rel_pos):
Let's move the helper functions to the bottom of the file
Addressed in ac7f30e
        x_out = ops.convert_to_numpy(attention_with_rel_pe(x))
        self.assertEqual(x_out.shape, (1, 64, 64, 1280))

    def test_windowed_transformer_encoder(self):
Let's add a test for ViTDetPatchingAndEmbedding as well
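Such a test might look roughly like this, mirroring the test quoted above (a sketch inside the same test class; the constructor arguments and values are assumptions, not taken from the actual PR):

# Hypothetical test sketch; assumes the np/ops imports already in this file.
def test_vit_det_patching_and_embedding(self):
    layer = ViTDetPatchingAndEmbedding(
        kernel_size=(16, 16), strides=(16, 16), embed_dim=1280
    )
    x = np.ones((1, 1024, 1024, 3))
    x_out = ops.convert_to_numpy(layer(x))
    # A 1024x1024 image with 16x16 patches yields a 64x64 grid of embeddings.
    self.assertEqual(x_out.shape, (1, 64, 64, 1280))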
Addressed in ac7f30e
from keras_cv.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_cv.models")
Sorry that this changed while this PR is in-flight, but if you sync to master, this should now be:
from keras_cv.api_export import keras_cv_export
@keras_cv_export("keras_cv.models.ViTDetBackbone")
(Same for all new public API symbols)
Addressed in ac7f30e
def __init__(
    self,
    img_size=1024,
We've standardized on input_shape for this in other backbones (accepting a tuple of (height, width, channels)).
I tried to add an input_shape attribute, but since I am not using the Functional syntax for building the model, it doesn't allow me to add that attribute.
In case you didn't notice, I am using the call method to specify the computations in the model instead of passing a symbolic input through each layer/operation in the __init__ method.
I noticed that it's just easier not to deal with symbolic inputs with Keras Core. One of the main reasons why Keras Core struggles with symbolic inputs is that it doesn't do shape inference. For example, in Keras, this works:
import tensorflow as tf
from tensorflow import keras

x = keras.Input([2, 3])
tf.shape(x)[0] * 10  # Note that even though the shape at axis 0 is None,
                     # TensorFlow returns a symbolic tensor, making computations
                     # like these valid instead of throwing an exception.
but Keras Core fails, since x.shape[0] is None.
I think it should not be difficult to convert the implementation to fully use the Functional syntax, but we would have to manually check whether the shapes we are getting are None or not. So, for example, we could do this in Keras Core:
import keras_core
from keras_core import ops

x = keras_core.Input([2, 3])
if x.shape[0] is not None:
    x = ops.reshape(x, (x.shape[0] * 2, 3))
else:
    x = ops.reshape(x, (-1, 3))  # -1 lets reshape infer the unknown dimension
or use some other operation.
Sorry for not highlighting these details properly beforehand! I will add a bunch of comments where I do something unintuitive so it's easier for future reviewers.
OK, turns out the shape errors were not a problem here. So, I got the model ported to the Functional syntax and it should now be consistent with other backbones. I have added include_rescaling and input_tensor arguments along with the input_shape argument. I also tested that the weights port correctly and can be saved/loaded in any backend. Let me know if this resolves the consistency issues.
@keras.utils.register_keras_serializable(package="keras_cv.models")
class ViTDetBackbone(Backbone):
I assume that this backbone doesn't produce pyramid-level outputs since it's a transformer architecture -- let's call this out in the docstring, and maybe even create an @property for self.pyramid_level_inputs which throws a nice NotImplementedError.
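For instance, something along these lines on the backbone class (a sketch of the suggestion; the exact message is illustrative):

# Sketch of the suggested property for ViTDetBackbone.
@property
def pyramid_level_inputs(self):
    raise NotImplementedError(
        "The ViTDet backbone produces a single-scale output and does not "
        "expose pyramid level inputs."
    )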
Addressed in ac7f30e
The backbone outputs the same shapes itself, but they do use a feature pyramid output: https://arxiv.org/pdf/2203.16527.pdf
TL;DR for the paper: the simple feature pyramid (on the right in their figure) turned out to be the most performant for them.
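In rough terms, the paper's simple feature pyramid builds multi-scale maps from the single stride-16 ViT output; here's a simplified sketch, omitting the lateral convolutions and normalization the paper uses (layer choices and output naming are assumptions for illustration):

# Simplified sketch of ViTDet's "simple feature pyramid".
import keras_core as keras

def simple_feature_pyramid(features, dim=256):
    # features: stride-16 output of the ViT backbone, channels-last.
    p2 = keras.layers.Conv2DTranspose(dim, 2, strides=2)(
        keras.layers.Conv2DTranspose(dim, 2, strides=2)(features)
    )  # stride 4
    p3 = keras.layers.Conv2DTranspose(dim, 2, strides=2)(features)  # stride 8
    p4 = keras.layers.Conv2D(dim, 1)(features)                      # stride 16
    p5 = keras.layers.MaxPooling2D(pool_size=2)(features)           # stride 32
    return {"P2": p2, "P3": p3, "P4": p4, "P5": p5}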
Hi @DavidLandup0, thanks for pointing out the paper! I didn't know the authors also proposed an FPN! I did look into the paper but don't know how the pyramid-level inputs would fit into the backbone here. Given this PR has already blown up a bit, I'd prefer to do this as a follow-up. Maybe you can take it up if you have time :)
Doing this as a follow-up sgtm
@keras.utils.register_keras_serializable(package="keras_cv")
class MLP(keras.layers.Layer):
Is this different from MLPBlock?
We could unite those. The only difference is that MLPBlock has the architecture embedding_dim -> mlp_dim -> embedding_dim, while MLP has the architecture input_dim -> [hidden_dim] * (num_layers - 1) -> output_dim. Looks like low-hanging fruit; I will address it in the next commit.
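A sketch of what the unified layer could look like (argument names are assumptions, not the PR's final code); the general form subsumes MLPBlock as the num_layers=2, hidden_dim=mlp_dim, output_dim=embedding_dim case:

# Sketch only; not the exact PR implementation.
import keras_core as keras

class MLP(keras.layers.Layer):
    def __init__(self, hidden_dim, output_dim, num_layers, activation="relu", **kwargs):
        super().__init__(**kwargs)
        # num_layers - 1 hidden projections followed by an output projection.
        self.dense_layers = [
            keras.layers.Dense(hidden_dim, activation=activation)
            for _ in range(num_layers - 1)
        ] + [keras.layers.Dense(output_dim)]

    def call(self, x):
        for layer in self.dense_layers:
            x = layer(x)
        return x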
Done.
from keras_cv.tests.test_case import TestCase


class TestSAM(TestCase):
nit: SAMTest
Done.
@tirthasheshpatel LMK when you're ready for another review on this 😄
They both behave exactly the same when moving_mean and moving_variance are None and epsilon is 1e-6
Force-pushed from 17a1cbb to 05d1d27.
- Use `keras_cv.export_api.keras_cv_export` instead of `keras.saving.register_keras_serializable`.
- Add a `SerializableSequential` class to address the saving bug with the `Sequential` model.
- Push the helper functions in `keras_cv/layers/detectron2_layers.py` to the bottom of the file.
- Add the detectron2 layers to the `keras_cv/layers/__init__.py` file.
- Add a test for the `ViTDetPatchingAndEmbedding` layer.
Force-pushed from 05d1d27 to ac7f30e.
Hi @ianstenbit, thank you very much for your reviews so far! Very helpful!
I have a
keras_cv/layers/__init__.py (Outdated)
@@ -17,6 +17,10 @@
from tensorflow.keras.layers import RandomWidth

from keras_cv.layers.augmenter import Augmenter
from keras_cv.layers.detectron2_layers import AddPositionalEmbedding
Do we want to put these under a detectron2 namespace?
Done.
keras_cv/layers/__init__.py (Outdated)
@@ -17,6 +17,10 @@
from tensorflow.keras.layers import RandomWidth

from keras_cv.layers.augmenter import Augmenter
from keras_cv.layers.detectron2_layers import AddPositionalEmbedding
Since this would be exported as part of the public API: we have PatchingAndEmbedding, which does patching with a Conv2D and then adds embeddings in this same form. Do we want to update that layer to use AddPositionalEmbedding as well for conformity?
Wouldn't adding a new layer in a pre-existing class invalidate the weights set for the ViT model? Also, since PatchingAndEmbedding is still a TensorFlow Keras layer, I think, for the time being, it'd be easier to keep the two separate.
I'd add a comment about this as a TODO, though, so we don't forget to do it in the future. What do you think?
Done.
keras_cv/layers/detectron2_layers.py (Outdated)
@keras_cv_export("keras_cv.layers.MultiHeadAttentionWithRelativePE")
class MultiHeadAttentionWithRelativePE(keras.layers.Layer):
Perhaps it would make sense to do an AddRelativePositionalEmbedding class for consistency with the aforementioned AddPositionalEmbedding?
Done.
keras_cv/layers/detectron2_layers.py (Outdated)
        )

        if self.use_rel_pos:
            attention_map = add_decomposed_rel_pos(
This should probably be a private method as part of a layer
Done.
keras_cv/layers/detectron2_layers.py (Outdated)
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]

            x, HW_padded = window_partition(x, self.window_size)
What do you think about doing this as a layer instead of a method?
I.e. https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/window_partitioning.py
Done. Instead of creating two classes, one for partitioning and one for unpartitioning, I handled both in a single class. Let me know if that looks good.
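Conceptually, the single class might combine both directions roughly like this (a sketch under assumed names and static shapes, not the PR's exact code):

# Sketch: pad to a multiple of the window size, cut into windows, and invert.
import keras_core as keras
from keras_core import ops

class WindowPartitioning(keras.layers.Layer):
    def __init__(self, window_size, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size

    def partition(self, x):
        # x: (batch, height, width, channels); static spatial dims assumed.
        B, H, W, C = x.shape
        pad_h = (-H) % self.window_size
        pad_w = (-W) % self.window_size
        x = ops.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
        Hp, Wp = H + pad_h, W + pad_w
        x = ops.reshape(
            x,
            (B, Hp // self.window_size, self.window_size,
             Wp // self.window_size, self.window_size, C),
        )
        windows = ops.reshape(
            ops.transpose(x, (0, 1, 3, 2, 4, 5)),
            (-1, self.window_size, self.window_size, C),
        )
        return windows, (Hp, Wp)

    def unpartition(self, windows, HW_padded, HW):
        # Inverse of partition: stitch windows back and crop off the padding.
        (Hp, Wp), (H, W) = HW_padded, HW
        C = windows.shape[-1]
        x = ops.reshape(
            windows,
            (-1, Hp // self.window_size, Wp // self.window_size,
             self.window_size, self.window_size, C),
        )
        x = ops.reshape(ops.transpose(x, (0, 1, 3, 2, 4, 5)), (-1, Hp, Wp, C))
        return x[:, :H, :W, :]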
keras_cv/layers/detectron2_layers.py (Outdated)
@keras_cv_export("keras_cv.layers.ViTDetPatchingAndEmbedding")
class ViTDetPatchingAndEmbedding(keras.layers.Layer):
This is the same as the ViT patching and embedding but without positional embedding.
I'm torn between adding a flag to turn off PE in the default layer, and having a new layer for this...
Are you referring to the PatchingAndEmbedding class for the ViT model in KerasCV? I addressed that here: #1987 (comment)
keras_cv/layers/detectron2_layers.py (Outdated)
        return config


def get_rel_pos(query_size, key_size, rel_pos):
This should probably be a private method or turned into a public layer.
I.e.: https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/decomposed_relative_positional_embedding.py
Done.
keras_cv/layers/detectron2_layers.py (Outdated)
    return ops.take(rel_pos_resized, relative_coordinates, 0)


def add_decomposed_rel_pos(
Same comment as above
Done.
# This only happens when the `build` method is called in the `__init__`
# step.
@keras_cv_export("keras_cv.layers.SerializableSequential")
class SerializableSequential(keras.layers.Layer):
Is this still an issue in Keras Core?
The bug has been addressed in Keras Core v0.1.5 but the latest TensorFlow Keras still has it. So, weights won't load in TensorFlow Keras until the bug is addressed in the next release.
We can either:
- Drop support temporarily for TensorFlow Keras just for this model with a note in the docs. With the new release of TF Keras, the bug should be fixed and we can remove the note.
- Keep the simple replication of the class until the bug is resolved in TensorFlow Keras.
I am leaning more towards option 2 but I don't have a strong opinion. What do you think @ianstenbit @DavidLandup0?
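For concreteness, the replication in option 2 might look roughly like this (a sketch; the serialization details are assumptions, not the PR's exact code):

# Sketch of a Sequential-like wrapper that serializes its sub-layers
# explicitly to sidestep the TF Keras saving bug discussed above.
import keras_core as keras

class SerializableSequential(keras.layers.Layer):
    def __init__(self, layers_list, **kwargs):
        super().__init__(**kwargs)
        self.layers_list = layers_list

    def call(self, x):
        for layer in self.layers_list:
            x = layer(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "layers_list": [
                keras.saving.serialize_keras_object(layer)
                for layer in self.layers_list
            ]
        })
        return config

    @classmethod
    def from_config(cls, config):
        config["layers_list"] = [
            keras.saving.deserialize_keras_object(layer)
            for layer in config["layers_list"]
        ]
        return cls(**config)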
I ended up removing it since the legacy weights load in all backends in Keras Core and also in TF Keras. I think until some saving issues are addressed with the new .weights.h5 format, we should just use the legacy weights. Let me know what you both think!
keras_cv/models/__init__.py (Outdated)
@@ -43,6 +43,18 @@
from keras_cv.models.backbones.densenet.densenet_backbone import (
    DenseNetBackbone,
)
from keras_cv.models.backbones.detectron2.detectron2_aliases import (
Afaik, the backbone is basically the same as the official ViTDet, so there may not be a need to call it SAM{name}Backbone.
Sounds good, thanks for looking into it!
Done.
keras_cv/models/__init__.py (Outdated)
@@ -166,5 +178,8 @@
    YOLOV8Detector,
)
from keras_cv.models.segmentation import DeepLabV3Plus
from keras_cv.models.segmentation import MaskDecoder
May be better as SAMMaskDecoder for clarity
Done.
Probably left by accident?
I have added this intentionally. This is used in the tests to verify that the model weights are loaded correctly and that the forward pass in all backends yields the same result.
"""Dictionary of preset names and configurations.""" | ||
return copy.deepcopy(backbone_presets) | ||
|
||
# @classproperty |
stray comment?
This method loads the presets with weights. I will uncomment it later once the model layers are finalized and the final weights are uploaded.
@keras_cv_export("keras_cv.layers.MLP")
class MLP(keras.layers.Layer):
This is a public class here -- it should probably be private, especially since there was an MLP with the same name in a related layer, iirc.
Done. I have removed the export.
@@ -0,0 +1,230 @@
# Copyright 2023 The KerasCV Authors
I'd probably put this layer under sam layers
Done.
@keras_cv_export("keras_cv.models.MaskDecoder")
class MaskDecoder(keras.models.Model):
    """Mask decoder for the segment anything model.
Nit: "Segment Anything (SAM)"
Done.
@keras_cv_export("keras_cv.models.MaskDecoder")
class MaskDecoder(keras.models.Model):
As mentioned before, to avoid confusion, probably best if this is called SAMMaskDecoder or something along those lines
Done.
        network. Defaults to "gelu".

    References:
        - [Segment Anything](https://arxiv.org/abs/2304.02643)
We probably want a code reference as well
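For example, the docstring's reference list could add the official repository alongside the paper (the URL below is the official facebookresearch repo):

    References:
        - [Segment Anything paper](https://arxiv.org/abs/2304.02643)
        - [Segment Anything GitHub](https://github.com/facebookresearch/segment-anything)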
I think the PR is almost ready for some thorough reviews except for a few TODOs:
Let me know if you have any other major points @DavidLandup0 @ianstenbit. And thanks for the reviews so far, super helpful!
An update on this: the legacy weights
The only thing left that I see is adding a preset for the pre-trained version of SAM.
Thanks for your great work!
/gcbrun
This is awesome -- thanks Tirth!
Just one little fix to make GCBRun happy
/gcbrun
Thanks, @DavidLandup0 @ianstenbit for your reviews! This was fun to work on. Excited to have this in KerasCV! The next steps are to add some guides to use and train the model. It would also be nice to have some benchmarks. On it now! But I will also create a tracking issue in case the community wants to take over some of these tasks.
Thank you Tirth for your outstanding work on this -- we really appreciate it! I think our long-term goal should be to add support for text prompts. There are some community projects out there which demonstrate the feasibility of this, and I think it would be a great step for us. But I 100% agree that some guides and training are the right place to start!
* Start adding components for the segment anything model
* SAMLayerNormalization -> keras.layers.LayerNormalization
  They both behave exactly the same when moving_mean and moving_variance are None and epsilon is 1e-6
* Move the image encoder to detectron2 backbone and fix for tf.keras backend
* Address review comments and address saving bug
  - Use `keras_cv.export_api.keras_cv_export` instead of `keras.saving.register_keras_serializable`.
  - Add a `SerializableSequential` class to address the saving bug with the `Sequential` model.
  - Push the helper functions in `keras_cv/layers/detectron2_layers.py` to the bottom of the file.
  - Add the detectron2 layers to the `keras_cv/layers/__init__.py` file.
  - Add a test for the `ViTDetPatchingAndEmbedding` layer.
* Make the backbone functional; unite MLP and MLPBlock
* Address David's review comments
* Add SAM Task model; make MaskDecoder and PromptEncoder XLA compatible
* Remove a stray file
* Add docs for the Task model
* Add more references [skip ci]
* Remove SerializableSequential layer
* detectron2 -> vit_det; add SAM presets; fix ViTDet presets
* Increase test tolerance for GCB Run
What does this PR do?
This PR implements the Segment Anything Model in multi-backend Keras.
Fixes #1679
See also #1933
Who can review?
@ianstenbit @DavidLandup0